# #####################################
# #                                   #
# #    NOT FULLY FUNCTIONAL CODE      #
# #                                   #
# #####################################
#
# This is not fully functional python code.  These are some code snippets from
# the pipeline we used to convert the FARSITE data into the tensorflow
# TFRecordio files that were used in our training and evaluation pipelines.
# This code should help fill in details for how to take similar numpy files and
# convert them into data that you can apply to the models.
#
# We've supplied two Apache Beam DoFns that performed the bulk of the work, as
# well as some of the Beam pipeline code that calls those functions.  Unless
# you are generating a large amount of data, you likely don't need Apache Beam
# at all, and can instead lift the logic out into ordinary python code that
# simply saves numpy arrays.
#
# The bulk of the real work is done in ProcessInputFile.  Understanding what is
# done in that function is most of what you'd need to replicate this code in a
# non-beam setting.  We've included functions for loading up all the data files
# and code for applying normalization as well.

#
# run(cfg) was called to create the data:
#

# cfg is just a map-like object (A ConfigDict more specifically) that stores
# key-value string pairs.  It's used to store locations of file names mostly.
# Treat it like a typical python dictionary.

def run(cfg: ConfigDict):
  """Finds every burn file under the input directory and runs the pipeline.

  Args:
    cfg: The configuration.  This function reads cfg.input_base_dir and
      cfg.burn_filename_prefix to build the glob for the burn files.
  """
  # Perform tasks for starting up the pipeline and initializing Beam.
  # ...

  # Burn files are named '<prefix>_<something>_burn.npy'; each one found becomes
  # one element of the pipeline's input.
  burn_basename_glob = f'{cfg.burn_filename_prefix}_*_burn.npy'
  input_glob = os.path.join(cfg.input_base_dir, burn_basename_glob)
  print('Looking for all files that match: {}'.format(input_glob))
  input_files = Glob(input_glob)
  print('The set of input files found: ', input_files)

  # Hand the pipeline definition to the runner and execute it.
  flume_runner = runner.FlumeRunner()
  pipeline_results = flume_runner.run(DataPipeline(cfg, input_files))

#
# This code generates the Beam pipeline.
#
def _load_norms(input_filename: str) -> Dict[str, Tuple[int, float, float]]:
  """Loads norms from input_filename and returns them.

  Args:
    input_filename: Path of the text file that contains the normalization
      dictionary.

  Returns:
    A dictionary keyed by channel name (e.g. 'wind_east', 'moisture_1',
    'cover').  ProcessInputFile._normalize reads element [1] of each value as
    the mean and element [2] as the variance; element [0] is not used by the
    code shown here (presumably a sample count -- TODO confirm).
  """
  # NOTE: The implementation is elided from this snippet, so as written the
  # function body contains only documentation.
  #
  # This function just needs to read in the text file that contains a python
  # dictionary, create the dictionary from that and return that dictionary.
  # Prefer ast.literal_eval over eval for the parsing step: it accepts plain
  # literals such as dicts, tuples and numbers, but will not execute arbitrary
  # code found in the file.
  # ...


class DataPipeline:
  """Callable that assembles the Beam graph which generates the data.

  The graph turns the list of burn filenames into a PCollection, loads the
  normalization coefficients once, and feeds both to ProcessInputFile (the
  coefficients as a singleton side input).
  """

  def __init__(self, cfg: ConfigDict, input_files: List[str]):
    """Stores the configuration and the burn filenames to process."""
    self.input_files = input_files
    self.cfg = cfg

  def __call__(self, root):
    """Builds the pipeline on top of the given Beam root."""
    # PCollection[str], unique rows: one element per burn file.
    burn_files = root | 'CreateInputFiles' >> beam.Create(self.input_files)

    # The norm coefficients are read eagerly, while the graph is being built,
    # and wrapped in a one-element PCollection so they can be a side input.
    # PCollection[Dict[str, Tuple[int, float, float]]]
    norm_coefficients = _load_norms(self.cfg.normalize_input_filename)
    norms_pcoll = root | beam.Create([norm_coefficients])

    # PCollection[tf.train.Example]
    results = burn_files | 'ProcessInputFiles' >> beam.ParDo(
        ProcessInputFile(self.cfg),
        norms=beam.pvalue.AsSingleton(norms_pcoll))

    # Results is a PCollection of tf.train.Examples.  You can save them into
    # a TF Recordio, or transform them into any kind of data (e.g., numpy
    # arrays) that you would prefer to use in whatever code you have to use the
    # Tensorflow models we've provided.


# Some constants in common.py
# Variables to track which channels hold which types of data.  Each value is
# the index along the last (channel) axis of the input frames assembled in
# ProcessInputFile.process, which asserts this ordering while stacking.

# Vegetation: all ones on the first step, 1.0 - lit[i-1] afterwards.
CHANNEL_TREE = 0
# The previous time step's label (the change in lit), zeros on the first step.
CHANNEL_LAST_LABEL = 1
# Cumulative burn up to the previous step (lit[i-1]), zeros on the first step.
CHANNEL_ASH = 2
# Wind components; these are sliced per time step.
CHANNEL_WIND_EAST = 3
CHANNEL_WIND_NORTH = 4
# Moisture channels, loaded from the 'moisture_1_hour', 'moisture_10_hour',
# 'moisture_100_hour', 'moisture_live_herbaceous' and 'moisture_live_woody'
# files respectively.
CHANNEL_MOISTURE_1 = 5
CHANNEL_MOISTURE_10 = 6
CHANNEL_MOISTURE_100 = 7
CHANNEL_MOISTURE_HER = 8
CHANNEL_MOISTURE_WOO = 9
# Channels loaded from the 'cover', 'height', 'base' and 'density' files.
CHANNEL_COVER = 10
CHANNEL_HEIGHT = 11
CHANNEL_BASE = 12
CHANNEL_DENSITY = 13
# Slope components.
CHANNEL_SLOPE_EAST = 14
CHANNEL_SLOPE_NORTH = 15
# Fuel channel; it is remapped (not normalized) before being stacked.
CHANNEL_FUEL = 16

#
# Utility functions to load files.
#
def _trim_boundary(cfg: ConfigDict,
                   numpy_data: np.ndarray) -> np.ndarray:
  """Strips the outermost ring of pixels when cfg.trim_pixels is truthy.

  Exactly one pixel is removed from each edge of the last two (spatial) axes,
  whatever the truthy value of cfg.trim_pixels is.

  Args:
    cfg: The configuration; only cfg.trim_pixels is read.
    numpy_data: A 2d (h, w) or 3d (t, h, w) array with square spatial dims.

  Returns:
    The input itself when trimming is disabled, otherwise the trimmed slice.

  Raises:
    ValueError: If trimming is enabled and the data is neither 2d nor 3d.
  """
  if not cfg.trim_pixels:
    return numpy_data

  if numpy_data.ndim not in (2, 3):
    raise ValueError('Expecting 2d or 3d data.')

  rows, cols = numpy_data.shape[-2:]
  assert rows == cols, 'assuming square'

  # The ellipsis leaves a leading time axis, if there is one, untouched.
  trimmed = numpy_data[..., 1:-1, 1:-1]
  print('  Reshaping to: ', trimmed.shape)
  return trimmed


def _get_one_file(cfg: ConfigDict,
                  which_file: str,
                  index: int):
  """Returns a single input data file as a float64 array.

  The file read is
  '<cfg.input_base_dir>/<cfg.burn_filename_prefix>_<index>_<which_file>.npy',
  with index zero-padded to three digits.

  Args:
    cfg: The configuration.
    which_file: The channel part of the filename (e.g. 'wind_east').
    index: The movie index used in the filename.

  Returns:
    The loaded array, converted to float64 and passed through _trim_boundary.

  Raises:
    ValueError: If the file cannot be unpickled.
  """
  file_path = os.path.join(
      cfg.input_base_dir,
      '{}_{:03d}_{}.npy'.format(cfg.burn_filename_prefix, index, which_file))
  try:
    print('Trying to load file: ', file_path)
    # NOTE(security): allow_pickle=True lets np.load execute pickled objects
    # embedded in the file.  Only use this on files from a trusted source.
    with Open(file_path, 'rb') as f:
      numpy_data = np.load(f, allow_pickle=True).astype('float64')
    print('Done loading file: ', file_path)

  except pickle.UnpicklingError as e:
    # Chain the original error so its traceback is kept as the cause.
    raise ValueError(
        f'There was an error unpickling the file {file_path}: {e}') from e
  print('Adding {} index {} with shape {} from file {}'.format(
      which_file, index, numpy_data.shape, file_path))
  numpy_data = _trim_boundary(cfg, numpy_data)
  return numpy_data


def _get_data(cfg: ConfigDict,
              burn_filename: Text) -> Tuple[int, List[Any]]:
  """Loads the data needed to process a single movie.

  Args:
    cfg: The configuration.
    burn_filename: The name of the file to load the burn tree data from.  Only
      the movie index is taken from it; the burn file that is actually opened
      is rebuilt from cfg.input_base_dir, cfg.burn_filename_prefix and that
      index (zero-padded to three digits).

  Returns:
    A tuple containing the datapoint's id, and the list of channels in the
    order: burn_duration, lit, wind_east, wind_north, moisture_1, moisture_10,
    moisture_100, moisture_her, moisture_woo, cover, height, base, density,
    slope_east, slope_north, fuel.

  Raises:
    ValueError: If the movie index cannot be parsed from burn_filename, if the
      lit data is not rank 3 with square spatial dims, or if any of the files
      is missing or cannot be loaded.
  """
  # Extract the movie index from the burn filename.  The prefix is a literal
  # piece of the filename, so it is escaped before going into the pattern, and
  # the pattern is a raw string so that '\d' reaches the regex engine intact.
  search_results = re.search(
      rf'{re.escape(cfg.burn_filename_prefix)}_(\d+)_burn', burn_filename)
  if not search_results:
    raise ValueError('Could not find burn file.')

  movie_idx = int(search_results.group(1))
  lit_filename = os.path.join(
      cfg.input_base_dir,
      '{}_{:03d}_burn.npy'.format(cfg.burn_filename_prefix, movie_idx))

  # The filename parts of the remaining channels, in the order in which they
  # appear in the returned list.
  channel_files = (
      'wind_east',
      'wind_north',
      'moisture_1_hour',
      'moisture_10_hour',
      'moisture_100_hour',
      'moisture_live_herbaceous',
      'moisture_live_woody',
      'cover',
      'height',
      'base',
      'density',
      'slope_east',
      'slope_north',
      'fuel',
  )

  # Load the lit file and the additional channels.
  try:
    # The lit file indicates how much of the fuel in a location has burnt away.
    print('Loading lit file: {}'.format(lit_filename))
    # NOTE(security): allow_pickle=True lets np.load execute pickled objects
    # embedded in the file.  Only use this on files from a trusted source.
    with Open(lit_filename, 'rb') as f:
      lit = np.load(f, allow_pickle=True).astype('float64')  # shape = (thw)
      lit = _trim_boundary(cfg, lit)
    if len(lit.shape) != 3:
      raise ValueError('Expecting lit file to have rank 3.')
    if lit.shape[1] != lit.shape[2]:
      raise ValueError('Expecting the spatial dims of lit to be square.')

    # The burn_duration value specifies how much fuel is located in a given
    # position.  All fuel starts at a value of 1.0.
    burn_duration = np.ones(shape=lit.shape[1:], dtype=np.float32)  # (hw)

    channels = [
        _get_one_file(cfg, which_file, movie_idx)
        for which_file in channel_files
    ]

  except FileNotFoundError as e:
    raise ValueError(f'Saw error on movie_idx {movie_idx}: {e}') from e

  except ValueError as e:
    raise ValueError(f'Saw ValueError on movie_idx {movie_idx}: {e}') from e

  except pickle.UnpicklingError as e:
    # _get_one_file converts its own unpickling errors into ValueError, so
    # reaching this handler means the lit file itself could not be unpickled.
    raise ValueError(
        f'There was an error unpickling lit filename: {lit_filename}: {e}'
    ) from e

  return (movie_idx, [burn_duration, lit] + channels)


#
# This is the code where all the real work happens.  Each set of input files
# generated from a single run of FARSITE is loaded in a single call to this
# DoFn, resulting in the data that can be used by our models.
#
def _add_frame_to_results(
    frame: np.ndarray, results: np.ndarray, height: int, width: int) -> (
        np.ndarray):
  """Appends frame to results as one more channel.

  Args:
    frame: Data holding exactly height * width values.
    results: Array whose axis 3 is the channel axis.
    height: The height of the frame.
    width: The width of the frame.

  Returns:
    A new array: results with the frame, reshaped to (1, height, width, 1),
    concatenated along axis 3.
  """
  single_channel = np.reshape(frame, (1, height, width, 1))
  return np.concatenate([results, single_channel], axis=3)


def _norm_data_by_mean_var(data: Any, mean: float, var: float) -> Any:
  """Centers data on mean and scales it by the standard deviation sqrt(var).

  Args:
    data: The value(s) to normalize.
    mean: The mean to subtract.
    var: The variance; the data is divided by its square root.

  Returns:
    data unchanged when both mean and var are zero, data - mean when only var
    is zero, and (data - mean) / sqrt(var) otherwise.
  """
  if var == 0.0:
    # There is nothing to divide by; at most the data can be centered.
    return data if mean == 0.0 else data - mean
  return (data - mean) / math.sqrt(var)


class ProcessInputFile(beam.DoFn):
  """Processes a single input file and generates its output examples.

  PCollection[str] -> PCollection[tf.train.Example]

  For one burn filename this loads all the channels of the movie, normalizes
  them, builds one 17-channel input frame and one label frame per time step,
  and then yields the static and temporal tf.train.Examples cut from those
  frames.
  """

  # Movies with fewer time steps than this are thrown away.
  _MIN_TIME_STEPS = 20

  def __init__(self, cfg: ConfigDict):
    self.cfg = cfg

    # Counters.
    self._file_removed_short_sequences = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.file_removed_short_sequences')
    self._file_kept = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.file_kept')
    self._file_testing = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.file_testing')
    self._file_training = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.file_training')
    self._data_temporal = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.data_temporal')
    self._data_static = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.data_static')
    self._file_error = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.ERROR_file_skipped')
    self._no_growth = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.WARNING_no_growth')
    self._broken_input = beam.metrics.Metrics.counter(
        self.__class__, 'ProcessInputFile.ERROR_BROKEN_INPUT')

  def _normalize(self,
                 norms: Dict[str, Tuple[int, float, float]],
                 data: Tuple[int, List[Any]]) -> (Tuple[int, List[Any]]):
    """Normalizes the data.

    Args:
      norms: Maps a channel name to a tuple whose element [1] is the mean and
        whose element [2] is the variance of that channel.
      data: The (movie_idx, channels) tuple as returned by _get_data.

    Returns:
      A tuple of the same layout as data.  The burn_duration, lit and fuel
      channels are passed through unchanged, every other channel is
      normalized.

    Raises:
      ValueError: If data does not hold exactly the 16 expected channels.
      KeyError: If norms lacks an entry for a normalized channel.
    """
    # Break the data up into pieces.  The unpacking doubles as a check that
    # exactly the expected channels were handed in.
    (movie_idx,
     [burn_duration, lit, wind_east, wind_north, moisture_1, moisture_10,
      moisture_100, moisture_her, moisture_woo, cover, height_channel,
      base, density, slope_east, slope_north, fuel]) = data

    def norm(channel, key):
      return _norm_data_by_mean_var(channel, norms[key][1], norms[key][2])

    # burn_duration, lit and fuel channels are not normalized, so they remain
    # unchanged.
    return (movie_idx,
            [burn_duration,
             lit,
             norm(wind_east, 'wind_east'),
             norm(wind_north, 'wind_north'),
             norm(moisture_1, 'moisture_1'),
             norm(moisture_10, 'moisture_10'),
             norm(moisture_100, 'moisture_100'),
             norm(moisture_her, 'moisture_her'),
             norm(moisture_woo, 'moisture_woo'),
             norm(cover, 'cover'),
             norm(height_channel, 'height'),
             norm(base, 'base'),
             norm(density, 'density'),
             norm(slope_east, 'slope_east'),
             norm(slope_north, 'slope_north'),
             fuel])

  def process(self, burn_filename: str,
              norms: Dict[str, Tuple[int, float, float]]):
    """Converts one FARSITE movie into tf.train.Examples.

    Args:
      burn_filename: The name of the burn file identifying the movie.
      norms: The normalization coefficients, see _normalize.

    Yields:
      First the examples built from single time points, then the examples
      built from time sequences of length cfg.sequence_length.  Nothing is
      yielded when the input is broken, has too few time steps, never starts
      growing, or has no frames left after the first growth step is dropped.
    """
    print('Working on file: ', burn_filename)
    # Get the data for this filename.
    try:
      (movie_idx,
       [burn_duration, lit, wind_east, wind_north, moisture_1, moisture_10,
        moisture_100, moisture_her, moisture_woo, cover, height_channel,
        base, density, slope_east, slope_north, fuel]) = self._normalize(
            norms, _get_data(self.cfg, burn_filename))
    except ValueError as e:
      print(f'ERROR detected, skipping burn_filename {burn_filename}: {e}')
      self._broken_input.inc()
      return

    # Check to see if the data is data we want to keep.
    if lit.shape[0] < self._MIN_TIME_STEPS:
      print('Throwing away files for {} (too little data)'.format(
          burn_filename))
      self._file_removed_short_sequences.inc()
      return

    self._file_kept.inc()

    # Update the fuel so that it has values that range between 0 and ~40 (so it
    # can be used to feed a one-hot-encoding layer).
    # NOTE(review): FUEL_UPDATER_FUNC is not defined in the code shown here.
    fuel = FUEL_UPDATER_FUNC(fuel)

    num_time_steps = lit.shape[0]
    height = lit.shape[1]
    width = lit.shape[2]
    input_results = []
    label_results = []
    previous_lit_delta = np.zeros((height, width))

    # The channels that are the same for every time step, paired with the
    # channel index each of them has to end up at.
    static_channels = (
        (common.CHANNEL_MOISTURE_1, moisture_1),
        (common.CHANNEL_MOISTURE_10, moisture_10),
        (common.CHANNEL_MOISTURE_100, moisture_100),
        (common.CHANNEL_MOISTURE_HER, moisture_her),
        (common.CHANNEL_MOISTURE_WOO, moisture_woo),
        (common.CHANNEL_COVER, cover),
        (common.CHANNEL_HEIGHT, height_channel),
        (common.CHANNEL_BASE, base),
        (common.CHANNEL_DENSITY, density),
        (common.CHANNEL_SLOPE_EAST, slope_east),
        (common.CHANNEL_SLOPE_NORTH, slope_north),
        (common.CHANNEL_FUEL, fuel),
    )

    # Go through all the time points in lit, and build a single data point for
    # each of them.  The generated input_frame for time step i will be the
    # state of the tree and scar channels before the fire front in lit[i]
    # happened, so the change in lit at step i is the label.
    #
    # On the first iteration, the vegetation channel should be all ones, and
    # the scar channel should be all zeros, and the previous front channel
    # should be all zeros.
    #
    # On the nth iteration, the vegetation channel should be equal to 1.0 -
    # lit[i-1], as the value inside lit is the cumulative amount of burn up
    # until time frame i.
    last_lit_sum = -1.0
    growth_start_index = -1
    for i in range(num_time_steps):
      if i == 0:
        lit_delta = lit[i]
        tree_frame = burn_duration
        ash_frame = np.zeros((height, width))
      else:
        lit_delta = lit[i] - lit[i-1]
        tree_frame = 1.0 - lit[i-1]
        ash_frame = lit[i-1]

      # The label is how much tree mass burned down in this time step.  The
      # shape is passed positionally, which every NumPy version accepts.
      label_frame = np.reshape(lit_delta, (1, height, width, 1))

      # Check to see if the fire has started growing yet.
      curr_lit_sum = np.sum(lit[i])
      if (last_lit_sum > 0.0 and curr_lit_sum != last_lit_sum and
          growth_start_index == -1):
        # This is the first round growth started.
        growth_start_index = i
      last_lit_sum = curr_lit_sum

      # Channel 0 is the tree frame, aka, vegetation channel.  The remaining
      # channels are stacked on top of it one by one; each assert verifies that
      # the channel lands at the index the common constants promise.
      step_channels = (
          (common.CHANNEL_LAST_LABEL, previous_lit_delta),
          (common.CHANNEL_ASH, ash_frame),
          (common.CHANNEL_WIND_EAST, wind_east[i:i+1, ::, ::]),
          (common.CHANNEL_WIND_NORTH, wind_north[i:i+1, ::, ::]),
      ) + static_channels
      input_frame = np.reshape(tree_frame, (1, height, width, 1))
      for channel_index, channel in step_channels:
        assert input_frame.shape[3] == channel_index
        input_frame = _add_frame_to_results(
            channel, input_frame, height, width)
      previous_lit_delta = lit_delta

      # Keep the input and label only once the fire has started growing.
      if growth_start_index != -1:
        input_results.append(input_frame)
        label_results.append(label_frame)
    # end for
    print('Growth for the time series started on step: ', growth_start_index)

    if growth_start_index < 0:
      print('This fire sequence did not grow, skipping it: ', movie_idx)
      self._no_growth.inc()
      return

    # Discard the first step to ensure that all first steps in the result have a
    # previous fire front that is non-zero.
    input_results = input_results[1:]
    label_results = label_results[1:]
    assert len(input_results) == len(label_results)
    print(f'There were {len(input_results)} images in the input result.')

    # If growth only started on the very last time step, nothing is left after
    # the discard, and np.vstack would raise on the empty lists.
    if not input_results:
      self._file_error.inc()
      print('ERROR: no frames left after discarding the first growth step, '
            'skipping it: ', movie_idx)
      return

    # Create a single tensor containing the entire time sequences, and combine
    # the input and label results into a single list to save.
    input_full = np.vstack(input_results)
    label_full = np.vstack(label_results)
    results = [input_full, label_full]

    # At this point, input_full and label_full are numpy arrays that contain
    # the entire fire sequence.  From here, we had functions that would yield
    # numpy arrays that were the correct shape to feed into the models.  The
    # EPD model would take in single time points and the EPD-ConvLSTM model
    # would take in sequences of 8 time points.  We then converted those numpy
    # arrays into tf.train.Examples that a Tensorflow Dataset could load in for
    # our training and inference pipelines.  Here, you could just save the
    # numpy arrays, and then in the code you have that uses our Tensorflow
    # models, you would load in the saved numpy arrays, split them up
    # accordingly into data points for the model, and feed them into the model.
    # But below this comment block is the code that will save the TF Record
    # IOs.
    #
    # NOTE(review): set_label is not defined in the code shown here.

    # Create the static data points.
    for instance in fire_seq_convertor.generate_time_instances(
        results, self.cfg.tree_noise_mean, self.cfg.tree_noise_std,
        skip_initial_rows=self.cfg.skip_initial_rows):
      self._data_static.inc()
      yield fire_seq_convertor.get_tf_example(
          instance, set_label, movie_idx)

    # Create the temporal data points.
    for sequence in fire_seq_convertor.generate_time_sequences(
        results, self.cfg.sequence_length, self.cfg.tree_noise_mean,
        self.cfg.tree_noise_std, self.cfg.skip_initial_rows):
      self._data_temporal.inc()
      yield fire_seq_convertor.get_tf_example(sequence, set_label, movie_idx)


# Utility functions in the fire_seq_convertor module needed to split numpy
# arrays up into tf.train.Examples.

def generate_rows(data: List[Any], noise_mean: float, noise_std: float):
  """Yields the rows of data[0] and data[1] one pair at a time.

  When noise_std is non-zero, channel 0 of each input frame is multiplied by
  normal noise that is clipped to the range [0, 100000]; in that case the
  yielded input frame is a tensorflow tensor rather than a slice of data[0].

  Args:
    data: The data.  data[0] holds the inputs and data[1] the labels, both
      indexed by row along axis 0.
    noise_mean: Mean of normal noise generator.
    noise_std: Standard deviation of normal noise generator.  A value of 0.0
      turns the noise off.

  Yields:
    The input and label frame with shapes (1, h, w, c_input) and (1, h, w,
    c_label).
  """
  inputs = data[0]
  add_noise = noise_std != 0.0
  for row in range(inputs.shape[0]):
    this_row = slice(row, row + 1)  # Keeps the leading axis at length one.
    input_frame = inputs[this_row, ::, ::, ::]  # (1, h, w, c)
    if add_noise:
      height, width, num_channels = input_frame.shape[1:4]
      # Split into single channels so that only channel 0 gets the noise.
      channels = tf.split(input_frame, num_channels, axis=3)
      noise = tf.clip_by_value(
          tf.random.normal(
              shape=(1, height, width, 1),
              mean=noise_mean,
              stddev=noise_std,
              dtype=tf.dtypes.float64),
          0.0, 100000.0)
      channels[0] = tf.math.multiply(channels[0], noise)
      input_frame = tf.concat(channels, 3)
    yield (input_frame, data[1][this_row, ::, ::, ::])


def generate_time_sequences(data,
                            time_length: int,
                            noise_mean: float,
                            noise_std: float,
                            skip_initial_rows: int = -1):
  """Generates input/label pairs with a time dimension.

  The data is assumed to be a list of two numpy arrays.  The first element is
  the input data of shape (t, h, w, c_input).  The second element is the label
  data of shape (t, h, w, c_label).  I.e., all dimensions are the same, except
  the channel count.  NOTE: t is the length of an entire fire sequence, not the
  amount of time in a single temporal data point.

  The rows that remain after the initial rows are skipped are cut into
  consecutive, non-overlapping sequences of time_length rows each.  Rows left
  over at the end that do not fill a whole sequence are dropped.

  See cfg.training.tree_noise_mean for description of noise.

  Args:
    data: The data.
    time_length: How many time steps should be in the returned data points.
    noise_mean: The mean parameter to use when adding noise.
    noise_std: The std parameter to use when adding noise.
    skip_initial_rows: The number of rows in the data input to skip.  Values
      of zero or less skip nothing.

  Yields:
    A 2-tuple containing
      [0] the input data sequence (t, h, w, c_input)
      [1] the label data sequence (t, h, w, c_label)

  Raises:
    ValueError: If time_length is not positive.
  """
  # data[0] has rank 4 (t, h, w, c): one row per time step of the sequence.
  assert len(data) == 2
  assert len(data[0].shape) == 4
  if time_length <= 0:
    raise ValueError('time_length must be positive.')

  # Drop the skipped rows before creating the row generator.  That way the
  # rows handed out by the generator really start after the skipped ones, and
  # no noise is generated for rows that are thrown away.
  first_row = max(skip_initial_rows, 0)
  kept_data = [data[0][first_row:], data[1][first_row:]]
  generator = generate_rows(kept_data, noise_mean, noise_std)

  # Only whole sequences are emitted.  A trailing run of exactly time_length
  # rows still makes a full sequence.
  num_sequences = kept_data[0].shape[0] // time_length
  for _ in range(num_sequences):
    input_frames = []
    label_frames = []
    for _ in range(time_length):
      input_frame, label_frame = next(generator)
      input_frames.append(input_frame)
      label_frames.append(label_frame)

    input_sequence = np.concatenate(input_frames, axis=0)
    label_sequence = np.concatenate(label_frames, axis=0)
    yield (input_sequence, label_sequence)


def generate_time_instances(input_data: List[Any],
                            noise_mean: float,
                            noise_std: float,
                            skip_initial_rows: int = -1):
  """Like generate_time_sequences, but returns one time point. (h, w, c).

  Args:
    input_data: The data, a list holding the input array and the label array.
    noise_mean: The mean parameter to use when adding noise.
    noise_std: The std parameter to use when adding noise.
    skip_initial_rows: The number of leading rows that are not yielded.

  Yields:
    A 2-tuple of numpy arrays, the input frame and the label frame, both with
    the leading length-one axis removed.
  """
  rows = generate_rows(input_data, noise_mean, noise_std)

  # Return single data points without time.
  for row_index, (input_frame, label_frame) in enumerate(rows):
    # Skip the initial rows.
    if row_index < skip_initial_rows:
      continue
    assert label_frame.shape[0] == 1
    # Reshape to the frame's own shape minus its leading axis.
    input_row = tf.reshape(input_frame, shape=input_frame.shape[1:])
    label_row = tf.reshape(label_frame, shape=label_frame.shape[1:])
    yield (input_row.numpy(), label_row.numpy())
